import logging
from contextlib import suppress

import torch
import torch.nn.functional as F
from tqdm import tqdm

# from open_clip import tokenize
from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
from .imagenet_zeroshot_data import imagenet_classnames_zh, openai_imagenet_template_zh


def zero_shot_classifier(model, classnames, templates, args, amp=True,
                         tokenizer=None,
                         tokenizer_path='/share/project/xlm1024-33m-cls'):
    """Build a zero-shot classification head from text prompts.

    For every class name, each template is applied to produce a prompt. The
    prompts are tokenized (truncated/padded to 77 tokens) and encoded with
    ``encode_text``. The per-prompt embeddings are L2-normalized and averaged,
    and the mean is re-normalized to unit length. Per-class embeddings are
    stacked as columns of the returned tensor.

    Args:
        model: model exposing ``encode_text``. When ``args.distributed`` is set
            and ``args.horovod`` is not, the wrapped ``model.module`` is used.
        classnames: iterable of class-name strings.
        templates: iterable of callables mapping a class name to a prompt.
        args: namespace providing ``device``, ``distributed`` and ``horovod``.
        amp: currently ignored; autocast is disabled in this function.
        tokenizer: optional pre-built HuggingFace-style tokenizer. Passing one
            avoids reloading it from disk on every call.
        tokenizer_path: location from which an ``XLMRobertaTokenizer`` is
            loaded when ``tokenizer`` is None.

    Returns:
        Tensor on ``args.device`` with one column per class, presumably of
        shape ``(embed_dim, len(classnames))``. This assumes ``encode_text``
        returns ``(num_prompts, embed_dim)``; TODO confirm.
    """
    # NOTE: autocast is deliberately a no-op here regardless of ``amp``;
    # ``torch.cuda.amp.autocast`` was intentionally left disabled.
    autocast = suppress
    if tokenizer is None:
        # Imported lazily so callers that inject a tokenizer do not need
        # ``transformers`` to load a model from disk.
        from transformers import XLMRobertaTokenizer
        tokenizer = XLMRobertaTokenizer.from_pretrained(tokenizer_path)

    # Unwrap a DistributedDataParallel-style wrapper once, not per class.
    encoder = model.module if args.distributed and not args.horovod else model

    zeroshot_weights = []
    with torch.no_grad():
        for classname in tqdm(classnames):
            texts = [template(classname) for template in templates]  # format with class
            tokens = tokenizer(texts, truncation=True, max_length=77,
                               padding='max_length', return_tensors="pt")
            # Move each tokenizer output in place. This keeps the tokenizer's
            # own mapping type (e.g. BatchEncoding attribute access) intact.
            for key in tokens:
                tokens[key] = tokens[key].to(args.device)
            with autocast():
                class_embeddings = encoder.encode_text(tokens)
                # Average the unit-normalized prompt embeddings, then
                # re-normalize so every class column has unit norm.
                class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
                class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device)
    return zeroshot_weights


def accuracy(output, target, topk=(1,)):
    """Count correct top-k predictions for each ``k`` in ``topk``.

    Args:
        output: score tensor; top-k is taken along dim 1, so it is expected to
            be ``(batch, num_classes)``.
        target: ground-truth class indices, one per row of ``output``.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        List of Python floats. For each ``k``, the value is the number of
        samples whose target is among the ``k`` highest-scoring classes. It is
        a count, not a ratio.
    """
    # Materialize so a one-shot iterable (e.g. a generator) survives the two
    # passes below.
    topk = tuple(topk)
    # Clamp so a k larger than the number of classes counts all classes
    # instead of making ``Tensor.topk`` raise.
    maxk = min(max(topk), output.size(1))
    pred = output.topk(maxk, 1, True, True)[1].t()
    # ``reshape`` (unlike ``view``) also works for non-contiguous targets.
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    # ``.item()`` avoids NumPy's deprecated array-to-scalar ``float()`` path.
    return [correct[:k].reshape(-1).float().sum().item() for k in topk]


def run(model, classifier, dataloader, args):
    """Measure zero-shot top-1/top-5 accuracy of ``classifier`` on ``dataloader``.

    Each batch of images is encoded with ``encode_image`` and L2-normalized.
    The features are multiplied with ``classifier`` and scaled by 100 to
    produce logits, which are scored with ``accuracy``.

    Args:
        model: model exposing ``encode_image``. When ``args.distributed`` is set
            and ``args.horovod`` is not, the wrapped ``model.module`` is used.
        classifier: class-embedding matrix, with classes along the last dim.
        dataloader: iterable of ``(images, target)`` batches.
        args: namespace providing ``device``, ``distributed``, ``horovod`` and
            ``batch_size`` (used only for the progress bar scale).

    Returns:
        ``(top1, top5)`` as fractions of the samples seen. If the dataloader
        yields no samples, a warning is logged and ``(0.0, 0.0)`` is returned
        instead of dividing by zero.
    """
    # NOTE: ``args.precision == 'amp'`` is not honoured here; autocast is
    # intentionally not applied in this evaluation loop.
    encoder = model.module if args.distributed and not args.horovod else model

    top1, top5, n = 0., 0., 0.
    with torch.no_grad():
        for images, target in tqdm(dataloader, unit_scale=args.batch_size):
            target = target.to(args.device)
            image_features = F.normalize(encoder.encode_image(images.to(args.device)), dim=-1)
            # Fixed logit scale of 100 applied to the feature/classifier product.
            logits = 100. * image_features @ classifier

            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += images.size(0)

    if n == 0:
        logging.warning('Zero-shot evaluation dataloader yielded no samples.')
        return 0., 0.
    return top1 / n, top5 / n


def zero_shot_eval(model, data, epoch, args):
    """Run zero-shot ImageNet evaluation if it is scheduled for ``epoch``.

    Evaluation is skipped, returning ``{}``, in any of these cases:
    ``data`` has neither ``'imagenet-val'`` nor ``'imagenet-v2'``;
    ``args.zeroshot_frequency`` is 0; or ``epoch`` is not a multiple of
    ``args.zeroshot_frequency`` and is not ``args.epochs``.

    An English-prompt classifier is always built. When ``args.language_zh`` is
    truthy, a Chinese-prompt classifier is also built, and it is evaluated on
    ``'imagenet-val'`` only.

    Returns:
        Dict mapping metric names (``...-top1`` / ``...-top5``) to accuracies.
    """
    has_val = 'imagenet-val' in data
    has_v2 = 'imagenet-v2' in data
    if not (has_val or has_v2):
        return {}

    frequency = args.zeroshot_frequency
    if frequency == 0:
        return {}
    # ``args.epochs`` is only consulted for off-schedule epochs, so the final
    # epoch is always evaluated.
    if epoch % frequency != 0 and epoch != args.epochs:
        return {}

    logging.info('Starting zero-shot imagenet.')
    logging.info('Building zero-shot classifier')
    classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args)
    classifier_zh = None
    if args.language_zh:
        classifier_zh = zero_shot_classifier(
            model, imagenet_classnames_zh, openai_imagenet_template_zh, args)
    logging.info('Using classifier')

    results = {}

    def record(head, split, top1_key, top5_key):
        # Evaluate ``head`` on ``data[split]`` and store top-1 then top-5.
        results[top1_key], results[top5_key] = run(model, head, data[split].dataloader, args)

    if has_val:
        record(classifier, 'imagenet-val',
               'imagenet-zeroshot-val-top1', 'imagenet-zeroshot-val-top5')
        if classifier_zh is not None:
            record(classifier_zh, 'imagenet-val',
                   'imagenet-zeroshot-val-zh-top1', 'imagenet-zeroshot-val-zh-top5')
    if has_v2:
        record(classifier, 'imagenet-v2',
               'imagenetv2-zeroshot-val-top1', 'imagenetv2-zeroshot-val-top5')

    logging.info('Finished zero-shot imagenet.')
    return results
